[JAX] Remove GSPMD tests + adding guards and warning msg for GSPMD rules#2702
[JAX] Remove GSPMD tests + adding guards and warning msg for GSPMD rules#2702phu0ngng merged 2 commits intoNVIDIA:mainfrom
Conversation
Greptile SummaryThis PR deprecates GSPMD sharding propagation in the JAX backend of TransformerEngine in preparation for its full removal in June 2026. It removes all GSPMD-specific test variants across distributed test suites and example encoders, strips the
Key changes:
Notable concern: The deprecation warning in Confidence Score: 4/5
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[register_primitive called at import time] --> B{JAX version\n≤ 0.9.1?}
B -- Yes\n_JAX_GSPMD_SUPPORTED=True --> C{cls defines\ninfer_sharding_from_operands\nin own __dict__?}
C -- Yes --> D[_warn_gspmd_deprecation_once\nDeprecationWarning issued\nat import time]
D --> E[gspmd_kwargs = infer_sharding_from_operands\nfrom subclass]
C -- No --> F[gspmd_kwargs = infer_sharding_from_operands\nfrom BasePrimitive base — raises NotImplementedError]
E --> G[def_partition with partition +\nshardy_sharding_rule + infer_sharding_from_operands]
F --> G
B -- No\n_JAX_GSPMD_SUPPORTED=False --> H[gspmd_kwargs = empty dict]
H --> I[def_partition with only\npartition + shardy_sharding_rule]
G --> J[Primitive registered]
I --> J
style D fill:#ffcc00,color:#000
style F fill:#ff9999,color:#000
Last reviewed commit: 69a71c6 |
|
/te-ci JAX L1 |
|
/te-ci JAX L1 |
jberchtold-nvidia
left a comment
There was a problem hiding this comment.
LGTM once CI finishes, thanks for making this change!
This full removal is valid as users will be in one of the following situations, right?
a) The user is on the latest JAX version, in which case having these GSPMD functions around could cause errors since GSPMD has been removed, so removing this logic as we do in this PR is correct
b) The user is on an older version of JAX, in which case they can use GSPMD or Shardy. There have been updates from JAX itself about this transition so users have been aware they need to move to Shardy by March 2026, so in our case removing it should still be okay
c) The user is on a very old version of JAX (e.g. >1 year old), in which Shardy doesn't work or has bugs. In this case, they might also have other compatibility issues with TE/JAX, in which case they should update to a more recent JAX version
|
Per offline discussion, we decided that GSPMD sharding propagation rules will be kept until June 2026. This PR only removes the GSPMD tests. |
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
|
/te-ci JAX L1 |
Additional Comments (2)
Since all current public JAX releases are in the This is purely forward-looking, which is fine and the comment explains the intent. However, it's worth explicitly noting that until JAX reaches 0.9.2+, all registered primitives will still have Consider either:
The guard Instead, they'll silently receive the |
jberchtold-nvidia
left a comment
There was a problem hiding this comment.
Can you update the PR description with the latest info? The PR itself LGTM, thanks!
Additional Comments (1)
The PR description states "GSPMD rules with the existing primitives will work with older JAX versions (until 0.9.1) with a printed warning", which implies the warning should fire when GSPMD is actually invoked, not merely because GSPMD support code is present. A more accurate approach would be to emit the warning inside a wrapper around the GSPMD-callable that is passed to As written, a user who has already migrated to Shardy on an older JAX release will see a misleading "Use it at your own risk" warning even though they are not using GSPMD. |
Description
GSPMD sharding propagation is being deprecated in favour of Shardy, which is now the default JAX partitioner.
This commit removes all GSPMD-related tests. The GSPMD sharding propagation rules will be kept for another 3 months until June 2026.
GSPMD rules with the existing primitives will work with older JAX versions (until 0.9.1) with a printed warning.
For the incoming primitives that do not have the GSPMD rules, if users attempt to use them with GSPMD, an error will be raised before it crashes.
Type of change
Checklist: